''' 
Main module for hosting the chatbot LLM and generating responses.
Configures LLM settings, model used, tokenizer and base instructions.
Handles prompt building, response cleaning, and streaming output based on user input.
Provides endpoints for health checks and basic API info.
'''

from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, BitsAndBytesConfig, TextStreamer
import torch
import re
import os
import sys
import json
import yaml
import asyncio
from pathlib import Path
from typing import Dict, Any, Optional, AsyncGenerator, List, Optional
import threading
from queue import Queue

# Pydantic models
# Chatbot
class InferenceRequest(BaseModel):
    # Request payload consumed by answer_question_streaming.
    question: str  # the student's question; inserted verbatim into the prompt
    context: str  # prior conversation; forwarded to build_prompt, which currently does not use it
    vector_info: Optional[List[str]] = []  # joined with newlines to form the advisor-information section of the prompt
    database_info: Optional[dict] = {}  # only the "results" key is read; its value is JSON-dumped into the prompt

class Response(BaseModel):
    # Answer plus model metadata.
    # NOTE(review): not referenced anywhere else in the visible part of this file.
    answer_question: str
    model_info: Dict[str, Any]

# Roadmap
class PastCourse(BaseModel):
    # A course the student has already attempted.
    code: str  # the only field build_prompt_roadmap reads
    name: str
    passed: bool  # not read by build_prompt_roadmap; pass/fail is conveyed by which list the course is placed in

class NextCourse(BaseModel):
    # A candidate course for the student's next registration.
    code: str
    name: str
    entry_requirements: str  # written to the roadmap prompt only when non-empty
    corequisites: str  # written to the roadmap prompt only when non-empty
    restrictions: Optional[str] = None  # written to the roadmap prompt only when truthy

class RoadmapEvaluation(BaseModel):
    # Request payload for the roadmap evaluation functions.
    degree_name: str
    degree_code: str
    degree_notes: str  # inserted under the "DEGREE INFORMATION" heading of the roadmap prompt
    passed_courses: Optional[List[PastCourse]] = []  # None is treated like an empty list by build_prompt_roadmap
    failed_courses: Optional[List[PastCourse]] = []  # None is treated like an empty list by build_prompt_roadmap
    next_courses: Optional[List[NextCourse]] = []  # None is treated like an empty list by build_prompt_roadmap

# Base instructions for the chatbot.
# This text is placed at the very top of every prompt produced by build_prompt,
# so any wording the model is told to quote verbatim reaches students as-is.
base_instructions = """
    You are a knowledgeable university chatbot assistant that helps students navigate their academic journey. 
    Your role is to provide accurate, helpful, and timely information to support student success.

    IMPORTANT INFORMATION ABOUT THE DATABASE:
    • We only have data for the Commerce and Science faculty.
    • If asked about another faculty, tell the student: "If this response is not sufficient, use UCT's main page to find additional advice.".
    • Answer questions about departments to the best of your ability.

    RESPONSE STYLE:
    • Use a friendly, professional tone - be approachable but maintain credibility
    • Write in a conversational style appropriate for university students
    • Be direct and get straight to answering their question
    • Don't use greetings unless the student greets you first
    • Never use email formatting, signatures, or formal letter closings
    • Keep responses concise but thorough - provide necessary context without being verbose

    ACCURACY AND RELIABILITY:
    • Only provide information that is explicitly stated in the information provided
    • Never invent, guess, or make up course codes, requirements, deadlines, policies, or procedures
    • If specific information is missing or unclear, explicitly state this limitation

    HANDLING DIFFERENT QUESTION TYPES:
    
    DEGREE-SPECIFIC COURSE QUESTIONS:
    • **IMPORTANT**: If the question is asking about courses for a specific degree, program, major, or academic pathway (e.g., "What courses do I need for Computer Science degree?", 
        "If I am doing a business science degree specialising in computer science, what first year courses will I do?"), 
        respond with this exact message:
        "For detailed information about courses required for specific degrees, programs, and majors, please visit our Handbook page. You can access this by:

        1. Clicking on "Academic Info" in the navigation bar
        2. Selecting "Handbooks" from the dropdown menu

        The Handbook page contains comprehensive course lists, degree requirements, and academic pathways for all programs. 
        This will give you the most accurate and up-to-date information about what courses you need to take for your specific degree and year of study."
    
    WHEN INFORMATION IS LIMITED:
    • Specify what additional details might be needed
    • Suggest where students can find complete, up-to-date information
    
    FORMATTING:
    • Use bullet points or numbered lists for multi-step processes
    • Make contact information easy to find and use

    AVOID:
    ✗ Greeting the student e.g.: "Hi there," if they asked a question.
    ✗ Formal greetings like "Dear Student" or closings like "Best regards"
    ✗ Making up specific details not provided in your information sources  
    ✗ Giving vague responses when specific information is needed
    ✗ Using overly academic or bureaucratic language
"""

# Model configuration
MODEL_PATH = "/scratch/dhoward/vicuna"  # local directory handed to from_pretrained for both tokenizer and model
MAX_TOKENS = 1000  # used as max_new_tokens for every generate() call in this file
USE_GPU = torch.cuda.is_available()  # evaluated once, at import time

# Performance tuning parameters
# The loader checks the 4-bit flag first; the 8-bit flag is only consulted when 4-bit is off.
USE_4BIT_QUANTIZATION = False 
USE_8BIT_QUANTIZATION = True
BATCH_SIZE = 1  # NOTE(review): not referenced anywhere in the visible part of this file

# Model loading parameters 
# Read by the import-time loading code: 'use_fast' and 'trust_remote_code' go to
# the tokenizer; 'torch_dtype', 'device_map', 'max_memory', 'low_cpu_mem_usage'
# and 'trust_remote_code' go to the model.
MODEL_SETTINGS = {
    'low_cpu_mem_usage': True,
    # NOTE(review): the original remark ties float16 to 4-bit quantization, but
    # this dtype is passed to the model whichever quantization flag is active.
    'torch_dtype': torch.float16, # float 16 used for 4bit quant
    'device_map': 'auto',
    # Per-device memory caps; integer keys are GPU indices.
    'max_memory': {
        0: '22GiB',  # GPU 0
        1: '22GiB',  # GPU 1
        'cpu': '8GiB'
    },
    'trust_remote_code': False,
    'use_fast': True,
    'use_cache': False  # NOTE(review): never read in the visible part of this file
}

class CustomTextStreamer(TextStreamer):
    '''
    Console streamer used while the model generates.

    - Every put() re-decodes the whole accumulated token buffer and prints only
      the characters that have not been printed yet, so spacing between tokens
      is whatever the tokenizer produces for the full sequence.
    - Text is held back while the decoded buffer ends in U+FFFD, i.e. while the
      last tokens only form part of a multi-byte character. Held-back text is
      printed by a later put() or, at the latest, by end().
    - Call reset() before reusing an instance for another generation; end()
      deliberately keeps the buffers so get_generated_text() still works.
    - Used in the main model loading section: streamer = CustomTextStreamer(tokenizer, skip_prompt=True)
    '''
    # What decode() yields for a byte sequence that is not (yet) valid UTF-8.
    _INCOMPLETE_CHAR = "\ufffd"

    def __init__(self, tokenizer, skip_prompt=True, **decode_kwargs):
        super().__init__(tokenizer, skip_prompt, **decode_kwargs)
        self.generated_text = ""
        self.new_tokens = []  # not used by this class; kept so existing readers of the attribute do not break
        self.token_buffer = []  # Every generated token id seen so far
        self.printed_length = 0  # Number of decoded characters already printed

    def put(self, value):
        """Receive the next token id(s) from generate() and print any new text.

        Args:
            value: tensor of token ids, either 1-D or 2-D with a batch size of 1.

        Raises:
            ValueError: if ``value`` carries more than one batch row.
        """
        if len(value.shape) > 1 and value.shape[0] > 1:
            raise ValueError("TextStreamer only supports batch size 1")
        elif len(value.shape) > 1:
            value = value[0]

        # The first call delivers the prompt tokens; drop them when asked to.
        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return

        token_ids = value.tolist() if hasattr(value, 'tolist') else [value.item()]
        if not isinstance(token_ids, list):
            # tolist() on a zero-dimensional tensor returns a bare scalar.
            token_ids = [token_ids]
        self.token_buffer.extend(token_ids)

        # Decode the entire buffer to get proper spacing
        full_text = self.tokenizer.decode(self.token_buffer, **self.decode_kwargs)

        # Printing a replacement character now would leave it on screen and,
        # because the completed character has the same length, the real
        # character would never be printed. Wait for the remaining tokens.
        if full_text.endswith(self._INCOMPLETE_CHAR):
            return
        self._print_pending(full_text)

    def _print_pending(self, full_text):
        """Print the part of ``full_text`` that has not been printed yet."""
        if len(full_text) > self.printed_length:
            print(full_text[self.printed_length:], end='', flush=True)
            self.printed_length = len(full_text)
            self.generated_text = full_text

    def end(self):
        """Called when generation is complete; flushes held-back text first."""
        if self.token_buffer:
            self._print_pending(self.tokenizer.decode(self.token_buffer, **self.decode_kwargs))
        super().end()
        print()  # Add newline at the end

    def get_generated_text(self):
        """Get the complete generated text"""
        return self.generated_text

    def reset(self):
        """Reset the streamer for new generation"""
        self.generated_text = ""
        self.new_tokens = []
        self.token_buffer = []
        self.printed_length = 0
        # If the previous generation aborted before end() ran, this flag is
        # still False and the next prompt would be echoed; restore it.
        self.next_tokens_are_prompt = True

class SSETextStreamer(CustomTextStreamer):
    """Streamer that hands generated text to a consumer through a queue (for SSE).

    put() and end() are called by generate(); the consumer polls get_chunk().
    Queue items are dicts with the keys "type" ("chunk" or "complete"),
    "content" (newly produced text) and "full_text" (everything produced so far).
    Text is held back while the decoded buffer ends in U+FFFD (a partially
    generated multi-byte character) and is emitted once it completes, or by
    end() at the latest.
    """
    def __init__(self, tokenizer, skip_prompt=True, **decode_kwargs):
        super().__init__(tokenizer, skip_prompt, **decode_kwargs)
        self.chunk_queue = Queue()
        self.is_complete = False
        self.generation_started = False
        self._lock = threading.Lock()

    def put(self, value):
        """Receive the next token id(s) and enqueue any newly decoded text.

        Raises:
            ValueError: if ``value`` carries more than one batch row.
        """
        self.generation_started = True

        if len(value.shape) > 1 and value.shape[0] > 1:
            raise ValueError("TextStreamer only supports batch size 1")
        elif len(value.shape) > 1:
            value = value[0]

        # The first call delivers the prompt tokens; drop them when asked to.
        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return

        with self._lock:
            token_ids = value.tolist() if hasattr(value, 'tolist') else [value.item()]
            if not isinstance(token_ids, list):
                # tolist() on a zero-dimensional tensor returns a bare scalar.
                token_ids = [token_ids]
            self.token_buffer.extend(token_ids)

            # Decode the entire buffer to get proper spacing
            full_text = self.tokenizer.decode(self.token_buffer, **self.decode_kwargs)

            # A trailing replacement character means the last tokens are only
            # part of a multi-byte character. Emitting it would send garbage to
            # the client and could hide the real character, so wait.
            if full_text.endswith("\ufffd"):
                return
            self._enqueue_new_text(full_text)

    def _enqueue_new_text(self, full_text):
        """Queue the not-yet-emitted tail of ``full_text``. Caller must hold ``_lock``."""
        if len(full_text) > self.printed_length:
            self.chunk_queue.put({
                "type": "chunk",
                "content": full_text[self.printed_length:],
                "full_text": full_text
            })
            self.printed_length = len(full_text)
            self.generated_text = full_text

    def end(self):
        """Called when generation is complete; enqueues the final "complete" item."""
        # Emit anything put() was still holding back before signalling completion.
        with self._lock:
            if self.token_buffer:
                self._enqueue_new_text(self.tokenizer.decode(self.token_buffer, **self.decode_kwargs))
        super().end()
        with self._lock:
            self.is_complete = True
            self.chunk_queue.put({"type": "complete", "content": "", "full_text": self.generated_text})

    def get_chunk(self, timeout=None):
        """Return the next queued item, or None if nothing arrives within ``timeout`` seconds.

        With ``timeout=None`` the call blocks until an item is available.
        """
        from queue import Empty  # local import: the module only imports Queue

        try:
            return self.chunk_queue.get(timeout=timeout)
        except Empty:
            # Only "nothing available yet" maps to None; other errors propagate.
            return None

def get_model_settings(model_name: str) -> Dict[str, Any]:
    """Return the default generation settings as a fresh dict.

    Args:
        model_name: accepted for interface compatibility; currently ignored,
            because no per-model configuration file is read.

    Returns:
        Sampling parameters for ``generate()``. ``pad_token_id`` and
        ``eos_token_id`` are ``None`` here and are expected to be filled in by
        the caller once the tokenizer has been loaded.
    """
    return {
        'top_p': 0.8,
        'top_k': 40,
        'repetition_penalty': 1.05,
        'max_new_tokens': MAX_TOKENS,
        'temperature': 0.3,
        'do_sample': True,  # False means 0 temperature
        'pad_token_id': None,  # Will be set after tokenizer loads
        'eos_token_id': None,  # Will be set after tokenizer loads
    }

# Load model and tokenizer
# Runs once at import time and defines the module-level names model_settings,
# tokenizer, model and streamer. Any exception raised along the way is printed
# and the process exits with status 1.
try:
    if not os.path.exists(MODEL_PATH):
        raise FileNotFoundError(f"Model path '{MODEL_PATH}' does not exist.")
    model_settings = get_model_settings(MODEL_PATH)
    # Configure quantization 
    # 4-bit takes precedence over 8-bit; with both flags off no quantization
    # config is passed to the model at all.
    quantization_config = None
    if USE_4BIT_QUANTIZATION:
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=False,
            llm_int8_enable_fp32_cpu_offload=True,
            llm_int8_threshold=6.0
        )
        print("4-bit quantization enabled")
    elif USE_8BIT_QUANTIZATION:
        print("8-bit quantization enabled")
        quantization_config = BitsAndBytesConfig(
            load_in_8bit = True, 
            llm_int8_enable_fp32_cpu_offload=True,
            llm_int8_threshold=6.0
        )
    else: 
        # NOTE(review): "prevision" in the message below looks like a typo for "precision".
        print("No quantization enabled. Loading full prevision model...")
    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained( # using standard huggingface tokenizer
        MODEL_PATH,
        use_fast=MODEL_SETTINGS.get('use_fast', True),
        trust_remote_code=MODEL_SETTINGS.get('trust_remote_code', False)
    )   
    # Set pad token if it doesn't exist
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    print(f"Loading model...")
    # Only this subset of MODEL_SETTINGS is forwarded to the model loader.
    model_kwargs = {
        'torch_dtype': MODEL_SETTINGS['torch_dtype'],
        'device_map': MODEL_SETTINGS['device_map'],
        'max_memory': MODEL_SETTINGS['max_memory'],
        'low_cpu_mem_usage': MODEL_SETTINGS['low_cpu_mem_usage'],
        'trust_remote_code': MODEL_SETTINGS['trust_remote_code']
    }
    if quantization_config is not None:
        model_kwargs['quantization_config'] = quantization_config
    model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, **model_kwargs)
    model.eval()
    torch.backends.cudnn.benchmark = True
    # USE_GPU was itself computed from torch.cuda.is_available(), so the second
    # check below only matters if availability changed since import began.
    if USE_GPU and torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated() / 1024**3  # Convert to GB
        reserved = torch.cuda.memory_reserved() / 1024**3
        print(f"GPU Memory - Allocated: {allocated:.2f}GB, Reserved: {reserved:.2f}GB")
    # Update model settings with tokenizer info
    model_settings['pad_token_id'] = tokenizer.pad_token_id
    model_settings['eos_token_id'] = tokenizer.eos_token_id
    # Create pipeline with optimized settings
    # NOTE(review): no pipeline object is created in the visible code; the
    # generation functions call model.generate() directly.
    print("Creating pipeline...")
    device_str = "distributed across CPU & GPU" if MODEL_SETTINGS['device_map'] == "auto" else MODEL_SETTINGS['device_map']
    print(f"Model loaded successfully using {device_str}")
    # print(f"Model settings: {model_settings}") # for testing settings loaded
    # NOTE(review): this console streamer is not referenced by any function in
    # the visible part of this file.
    streamer = CustomTextStreamer(tokenizer, skip_prompt=True)
except Exception as e:
    print(f"Error initializing model: {str(e)}")
    sys.exit(1)

# Root endpoint for basic API information
def read_root():
    """Return static API metadata: name, version, status and a map of endpoint descriptions."""
    endpoint_descriptions = {
        "POST /generate-sql": "Convert natural language to SQL for academic database",
        "GET /health": "Check API health",
        "GET /schema": "View the database schema",
    }
    api_info = {}
    api_info["api"] = "Academic Database SQL Generation API"
    api_info["version"] = "1.0.0"
    api_info["status"] = "running"
    api_info["endpoints"] = endpoint_descriptions
    return api_info

# Health check endpoint
def health_check():
    """Return a static health payload plus the device type chosen at import time.

    "model_loaded" is a constant True; "device" is "cuda" when USE_GPU is set
    and "cpu" otherwise.
    """
    device_name = "cpu"
    if USE_GPU:
        device_name = "cuda"
    return {"status": "healthy", "model_loaded": True, "device": device_name}

def build_prompt(question: str, context_info: str = "", database_info: str = "", vector_info: str = "") -> str:
    '''
    Assemble the chatbot prompt: base instructions, the available information,
    the student's question, and a closing instruction.

    The information part is made of up to two sections, in this order:
    - database information
    - additional information retrieved for the question (vector search)
    A section is left out when its text is empty or only whitespace; when both
    are left out a fixed placeholder sentence is used instead.

    context_info (the earlier conversation) is accepted but not included in
    the prompt; it was removed due to context constraints.

    The finished prompt is printed for debugging and then returned.
    '''
    def has_content(text) -> bool:
        return bool(text and text.strip())

    sections = []
    if has_content(database_info):
        sections.append(f"Database Information:\n{database_info}\n")
    if has_content(vector_info):
        sections.append(f"Additional Information from Student Advisors:\n{vector_info}\n")

    if sections:
        information = "\n\n".join(sections)
    else:
        information = "No specific instructions have been provided for this question."

    final_prompt = f"""{base_instructions}

Information available to answer this question:
{information}

Student question which you must answer: {question}

Please provide a helpful, accurate response based on the information provided above. 
If the available information doesn't fully address the question, clearly indicate what's missing and suggest appropriate next steps.
"""
    print(f"Prompt:{final_prompt}") # for debugging
    return final_prompt

def build_prompt_roadmap(request: RoadmapEvaluation) -> str:
    '''
    Assemble the prompt used for roadmap evaluation.

    Passed and failed courses appear as comma-separated course codes, and each
    next-course option as a short multi-line description. Any of the three
    lists that is missing or empty is rendered as "None". The finished prompt
    is printed for debugging and then returned.
    '''
    def code_list(courses) -> str:
        # Only the course code is shown for past courses.
        codes = [entry.code for entry in (courses or [])]
        if not codes:
            return "None"
        return ", ".join(codes)

    def describe(option) -> str:
        # Optional details are added only when they carry a value.
        parts = [f"Course: {option.code} - {option.name}"]
        if option.entry_requirements:
            parts.append(f"\n  Entry Requirements: {option.entry_requirements}")
        if option.corequisites:
            parts.append(f"\n  Corequisites: {option.corequisites}")
        if option.restrictions:
            parts.append(f"\n  Restrictions: {option.restrictions}")
        return "".join(parts)

    passed_courses_str = code_list(request.passed_courses)
    failed_courses_str = code_list(request.failed_courses)

    descriptions = [describe(option) for option in (request.next_courses or [])]
    if descriptions:
        next_courses_str = "\n\n".join(descriptions)
    else:
        next_courses_str = "None"

    roadmap_prompt = f"""
YOUR TASK:
Using the degree information and the student's academic history, assess each of the next course options. For each course, provide:
1. An assessment of whether the student meets the entry requirements for their planned courses.
2. Identification of any prerequisite gaps or potential risks with academic path.
3. Suggestions for alternative courses if needed.
Please respond in a clear, structured format, using bullet points or numbered lists where appropriate. 

Please evaluate this academic roadmap for a {request.degree_name} ({request.degree_code}) student:
DEGREE INFORMATION:
{request.degree_notes}

ACADEMIC HISTORY:
Passed Courses: {passed_courses_str}
Failed Courses: {failed_courses_str}

NEXT COURSE OPTIONS:
{next_courses_str}
"""
    print(f"Roadmap Prompt:{roadmap_prompt}") # for debugging
    return roadmap_prompt

def clean_response_start(text: str) -> str:
    """Strip leading whitespace and one boilerplate prefix from a model response.

    Steps, in order:
    1. Leading whitespace is removed.
    2. The first matching entry of a fixed list of prefixes ("Answer:",
       "Response:", "---", ...) is removed, case-insensitively, together with
       the whitespace that follows it.
    3. Fallback, only when no listed prefix matched: if a colon occurs within
       the first 30 characters, everything up to and including the first colon
       is dropped and the remainder is stripped on both sides.
       NOTE(review): this also truncates legitimate openings such as
       "Note: ..." or a leading URL; narrow it if that proves harmful.

    Args:
        text: raw response text; empty or None is returned unchanged.

    Returns:
        The cleaned text.
    """
    if not text:
        return text

    # List of common prefixes to remove (case insensitive)
    prefixes_to_remove = [
        r'^Answer:\s*',
        r'^Response:\s*',
        r'^Reply:\s*',
        r'^Your response:\s*',
        r'^Chatbot response:\s*',
        r'^Bot response:\s*',
        r'^AI response:\s*',
        r'^Assistant response:\s*',
        r'^My response:\s*',
        r'^Here is my response:\s*',
        r'^Here\'s my response:\s*',
        r'^Student answer:\s*',
        r'^Question answer:\s*',
        r'^---\s*',
        r'^GPT Answer:\s*',
        r'^Please write your response below:\s*',
    ]

    stripped = text.lstrip()

    # Decide "did a prefix match?" from the substitution count. Comparing the
    # result with the original text is wrong whenever the input had leading
    # whitespace: the comparison then succeeds after the very first pattern and
    # the remaining patterns are never tried.
    for pattern in prefixes_to_remove:
        cleaned, replacements = re.subn(pattern, '', stripped, count=1, flags=re.IGNORECASE)
        if replacements:
            return cleaned

    # Fallback: nothing matched but there's a colon in the first ~30 characters
    if ':' in stripped[:30]:
        return stripped[stripped.find(':') + 1:].strip()
    return stripped

def clean_response_end(text: str) -> str:
    """Drop trailing end-of-sequence markers and surrounding whitespace.

    The markers </s>, <|endoftext|> and <|im_end|> are each removed once, in
    that order, and only when they sit at the very end of the text at the time
    they are checked. Empty input is returned unchanged.
    """
    if not text:
        return text

    trailing_markers = (
        r'\s*</s>\s*$',
        r'\s*<\|endoftext\|>\s*$',
        r'\s*<\|im_end\|>\s*$',
    )
    result = text
    for marker in trailing_markers:
        result = re.sub(marker, '', result)
    return result.strip()

async def _stream_generation(prompt: str, log_final_text: bool = False) -> AsyncGenerator[dict, None]:
    """Run generation for ``prompt`` in a worker thread and yield streaming events.

    Yields one {"type": "text", ...} dict per decoded chunk and finally one
    {"type": "complete", ...} dict carrying the cleaned full response.

    Args:
        prompt: the complete prompt text.
        log_final_text: when True, print the cleaned final text before the
            "complete" event is yielded.

    Raises:
        Exception: whatever ``model.generate`` raised in the worker thread,
            re-raised here after the queue has been drained.
    """
    sse_streamer = SSETextStreamer(tokenizer, skip_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt")
    if USE_GPU and torch.cuda.is_available():
        inputs = {k: v.cuda() for k, v in inputs.items()}

    worker_errors = []

    def generate():
        # Keep the failure so the consumer can report it. Without this, a
        # failed generation never signals completion and the polling loop
        # below would have no way to learn what went wrong.
        try:
            with torch.no_grad():
                model.generate(
                    **inputs,
                    max_new_tokens=MAX_TOKENS,
                    do_sample=model_settings['do_sample'],
                    temperature=model_settings['temperature'],
                    top_p=model_settings['top_p'],
                    top_k=model_settings['top_k'],
                    repetition_penalty=model_settings['repetition_penalty'],
                    pad_token_id=model_settings['pad_token_id'],
                    eos_token_id=model_settings['eos_token_id'],
                    streamer=sse_streamer,
                )
        except Exception as exc:
            worker_errors.append(exc)

    generation_thread = threading.Thread(target=generate)
    generation_thread.start()
    first_chunk = True
    while True:
        # Poll without waiting: a blocking wait here would stall the event
        # loop, and with it every other request, for the whole timeout.
        chunk_data = sse_streamer.get_chunk(timeout=0)
        if chunk_data is None:
            if generation_thread.is_alive():
                await asyncio.sleep(0.01)
                continue
            # The worker has exited, so nothing more can be enqueued. One last
            # poll picks up items it added between the poll above and its exit.
            chunk_data = sse_streamer.get_chunk(timeout=0)
            if chunk_data is None:
                break
        if chunk_data["type"] == "chunk":
            content = chunk_data["content"]
            full_text = chunk_data["full_text"]
            # Clean the beginning only on the first chunk
            if first_chunk:
                content = clean_response_start(content)
                full_text = clean_response_start(full_text)
                first_chunk = False
            yield {
                "type": "text",
                "content": content,
                "full_text": full_text,
                "status": "generating"
            }
        elif chunk_data["type"] == "complete":
            # Clean the end of the final response
            final_text = clean_response_end(chunk_data["full_text"])
            # Also clean the start in case it wasn't caught earlier
            final_text = clean_response_start(final_text)
            if log_final_text:
                print(f"Final text \n{final_text}")
            yield {
                "type": "complete",
                "content": "",
                "full_text": final_text,
                "status": "complete"
            }
            break
    generation_thread.join()
    if worker_errors:
        raise worker_errors[0]


async def answer_question_streaming(request: InferenceRequest) -> AsyncGenerator[dict, None]:
    """
    Streaming version that yields dictionaries (easier to work with in FastAPI)

    Builds the chatbot prompt from the request and yields "text" events while
    the model generates, followed by one "complete" event. Any exception,
    including one raised inside the generation thread, is reported as a single
    "error" event instead of being raised.
    """
    try:
        vector_info_str = "\n".join(request.vector_info or [])
        results = request.database_info.get("results") if request.database_info else None
        database_info_str = "No database results found for this question." if results is None else json.dumps(results, indent=2)
        context_str = request.context or ""
        prompt = build_prompt(
            question=request.question,
            context_info=context_str,
            database_info=database_info_str,
            vector_info=vector_info_str,
        )
        async for event in _stream_generation(prompt):
            yield event
    except Exception as e:
        yield {
            "type": "error",
            "content": f"Error generating response: {str(e)}",
            "status": "error"
        }


async def road_map_evaluation(request: RoadmapEvaluation) -> AsyncGenerator[dict,None]:
    """
    Evaluate the students roadmap choices and give them an idea of what to choose
    Returns streaming output. Not currently used in the frontend.
    Future implementation

    Yields the same event dictionaries as answer_question_streaming; any
    exception is reported as a single "error" event instead of being raised.
    """
    try:
        prompt = build_prompt_roadmap(request)
        print(prompt)
        async for event in _stream_generation(prompt, log_final_text=True):
            yield event
    except Exception as e:
        yield {
            "type": "error",
            "content": f"Error generating response: {str(e)}",
            "status": "error"
        }

async def road_map_evaluation_simple(request: RoadmapEvaluation) -> str:
    """
    Non-streaming version that returns the complete text response
    Current version used in the frontend.

    Args:
        request: degree information, academic history and next-course options.

    Returns:
        The cleaned model response. On any failure a string starting with
        "Error generating roadmap evaluation:" is returned instead of raising.
    """
    try:
        prompt = build_prompt_roadmap(request)
        inputs = tokenizer(prompt, return_tensors="pt")
        if USE_GPU and torch.cuda.is_available():
            inputs = {k: v.cuda() for k, v in inputs.items()}
        prompt_token_count = inputs["input_ids"].shape[1]

        def generate():
            # Gradient mode is per-thread, so no_grad must be entered inside
            # the function that actually runs on the worker thread.
            with torch.no_grad():
                return model.generate(
                    **inputs,
                    max_new_tokens=MAX_TOKENS,
                    do_sample=model_settings['do_sample'],
                    temperature=model_settings['temperature'],
                    top_p=model_settings['top_p'],
                    top_k=model_settings['top_k'],
                    repetition_penalty=model_settings['repetition_penalty'],
                    pad_token_id=model_settings['pad_token_id'],
                    eos_token_id=model_settings['eos_token_id'],
                )

        # generate() can run for a long time; keep it off the event loop so
        # other requests (for example health checks) are still served.
        outputs = await asyncio.get_running_loop().run_in_executor(None, generate)

        # The output sequence starts with the prompt tokens. Cut them off by
        # token count: slicing the decoded string by len(prompt) breaks as soon
        # as decoding does not reproduce the prompt character for character.
        new_tokens = outputs[0][prompt_token_count:]
        response_text = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
        # Clean the response
        response_text = clean_response_start(response_text)
        response_text = clean_response_end(response_text)
        return response_text
    except Exception as e:
        return f"Error generating roadmap evaluation: {str(e)}"
